{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import time\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import OrderedDict\n",
    "from torch.autograd import Variable\n",
    "from pathlib import Path\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pix2pixhd_dir = Path('../src/pix2pixHD/')\n",
    "\n",
    "import sys\n",
    "sys.path.append(str(pix2pixhd_dir))\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from options.train_options import TrainOptions\n",
    "from data.data_loader import CreateDataLoader\n",
    "from models.models import create_model\n",
    "import util.util as util\n",
    "from util.visualizer import Visualizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../data/train_opt.pkl', mode='rb') as f:\n",
    "    opt = pickle.load(f)\n",
    "    \n",
    "iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_loader = CreateDataLoader(opt)\n",
    "dataset = data_loader.load_data()\n",
    "dataset_size = len(data_loader)\n",
    "print('#training images = %d' % dataset_size)\n",
    "    \n",
    "start_epoch, epoch_iter = 1, 0\n",
    "total_steps = (start_epoch-1) * dataset_size + epoch_iter\n",
    "display_delta = total_steps % opt.display_freq\n",
    "print_delta = total_steps % opt.print_freq\n",
    "save_delta = total_steps % opt.save_latest_freq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = create_model(opt)\n",
    "visualizer = Visualizer(opt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):\n",
    "    epoch_start_time = time.time()\n",
    "    if epoch != start_epoch:\n",
    "        epoch_iter = epoch_iter % dataset_size\n",
    "    for i, data in enumerate(dataset, start=epoch_iter):\n",
    "        iter_start_time = time.time()\n",
    "        total_steps += opt.batchSize\n",
    "        epoch_iter += opt.batchSize\n",
    "\n",
    "        # whether to collect output images\n",
    "        save_fake = total_steps % opt.display_freq == display_delta\n",
    "        \n",
    "        ############## Forward Pass ######################\n",
    "        losses, generated = model(Variable(data['label']), Variable(data['inst']), \n",
    "            Variable(data['image']), Variable(data['feat']), infer=save_fake)\n",
    "        \n",
    "        # sum per device losses\n",
    "        losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]\n",
    "        loss_dict = dict(zip(model.module.loss_names, losses))\n",
    "\n",
    "        # calculate final loss scalar\n",
    "        loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5\n",
    "        loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0)\n",
    "        \n",
    "        ############### Backward Pass ####################\n",
    "        # update generator weights\n",
    "        model.module.optimizer_G.zero_grad()\n",
    "        loss_G.backward()\n",
    "        model.module.optimizer_G.step()\n",
    "\n",
    "        # update discriminator weights\n",
    "        model.module.optimizer_D.zero_grad()\n",
    "        loss_D.backward()\n",
    "        model.module.optimizer_D.step()\n",
    "        \n",
    "        #call([\"nvidia-smi\", \"--format=csv\", \"--query-gpu=memory.used,memory.free\"]) \n",
    "\n",
    "        ############## Display results and errors ##########\n",
    "        ### print out errors\n",
    "        if total_steps % opt.print_freq == print_delta:\n",
    "            errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}\n",
    "            t = (time.time() - iter_start_time) / opt.batchSize\n",
    "            visualizer.print_current_errors(epoch, epoch_iter, errors, t)\n",
    "            visualizer.plot_current_errors(errors, total_steps)\n",
    "\n",
    "        ### display output images\n",
    "        if save_fake:\n",
    "            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),\n",
    "                                   ('synthesized_image', util.tensor2im(generated.data[0])),\n",
    "                                   ('real_image', util.tensor2im(data['image'][0]))])\n",
    "            visualizer.display_current_results(visuals, epoch, total_steps)\n",
    "\n",
    "        ### save latest model\n",
    "        if total_steps % opt.save_latest_freq == save_delta:\n",
    "            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))\n",
    "            model.module.save('latest')            \n",
    "            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')\n",
    "\n",
    "        if epoch_iter >= dataset_size:\n",
    "            break\n",
    "       \n",
    "    # end of epoch \n",
    "    iter_end_time = time.time()\n",
    "    print('End of epoch %d / %d \\t Time Taken: %d sec' %\n",
    "          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))\n",
    "\n",
    "    ### save model for this epoch\n",
    "    if epoch % opt.save_epoch_freq == 0:\n",
    "        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))        \n",
    "        model.module.save('latest')\n",
    "        model.module.save(epoch)\n",
    "        np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')\n",
    "\n",
    "    ### instead of only training the local enhancer, train the entire network after certain iterations\n",
    "    if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):\n",
    "        model.module.update_fixed_params()\n",
    "\n",
    "    ### linearly decay learning rate after certain iterations\n",
    "    if epoch > opt.niter:\n",
    "        model.module.update_learning_rate()\n",
    "        \n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3.6",
   "language": "python",
   "name": "python3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
